# get predictions from hierarchical lasso
library(tidyverse)
library(data.table)
library(dtplyr)
library(network)

library(hierNet)

# Parallel backend: attach foreach/doMC and register a multicore backend.
# NOTE(review): the core count is hard-coded to 11, and nothing in the visible
# part of this file calls foreach directly -- presumably it is used by the
# sourced helpers; confirm.
require(foreach)
require(doMC)
registerDoMC(cores = 11)

# Project helper functions (contents not visible from this file).
source("pundits_functions.R")

# Session-level RNG seed; run_heldout_yhat() sets its own seed from its
# `seed` argument each time it is called.
set.seed(1111)

# Refit a hierNet path on each cross-validation training split and keep the
# held-out predictions of every fold.
#
# The folds are not drawn here: they are taken from `mod.cv$folds`, so the
# splits of an earlier cross-validation run are reused. (The centering error
# message below still refers to hierNet.cv, which this routine appears to be
# adapted from.)
#
# Args:
#   fit:     object inheriting from "hierNet.path". `fit$type` must be
#            "gaussian" or "logistic" and `fit$mx` must be non-NULL. Its
#            tuning fields (delta, diagonal, strong, rho, niter, sym.eps,
#            step, maxiter, backtrack, tol) are forwarded to every refit.
#   mod.cv:  object whose `$folds` is a list of held-out row-index vectors.
#   x:       predictor matrix (rows are observations).
#   y:       response vector aligned with the rows of `x`.
#   nfolds:  kept for interface compatibility only; the number of folds is
#            always length(mod.cv$folds).
#   trace:   forwarded to the path-fitting function.
#   lamlist: lambda values to refit at; defaults to `fit$lamlist`.
#
# Returns:
#   A list of class "hierNet.cv" with elements lamlist, cv.err, cv.se, lamhat,
#   lamhat.1se, nonzero, folds, yhat and call. `yhat` has one element per
#   fold: the value of predict() on the held-out rows (gaussian) or its
#   `$yhat` component (logistic).
get_yhats <-
  function(fit,
           mod.cv,
           x,
           y,
           nfolds = 10,
           trace = 0,
           lamlist = NULL) {
  this.call <- match.call()
  stopifnot(inherits(fit, "hierNet.path"))

  # Fail early on an unknown model type instead of hitting an undefined
  # error function halfway through the first fold.
  fit_type <- as.character(fit$type)
  if (length(fit_type) != 1L || !fit_type %in% c("gaussian", "logistic")) {
    stop("fit$type must be \"gaussian\" or \"logistic\".", call. = FALSE)
  }
  is_gaussian <- fit_type == "gaussian"
  errfun <- if (is_gaussian) {
    function(y, yhat) (y - yhat)^2
  } else {
    function(y, yhat) 1 * (y != yhat)
  }

  folds <- mod.cv$folds
  stopifnot(is.list(folds))
  nfolds <- length(folds)

  if (is.null(lamlist)) {
    lamlist <- fit$lamlist
  }

  # get whether fit was standardized based on fit$sx and fit$szz...
  if (is.null(fit$mx)) stop("hierNet object was not centered.  hierNet.cv has not been written for this (unusual) case.")
  stand.main <- !is.null(fit$sx)
  stand.int <- !is.null(fit$szz)

  n.lamlist <- length(lamlist)
  err2 <- matrix(NA, nrow = nfolds, ncol = n.lamlist)
  yhats <- vector("list", nfolds)

  for (ii in seq_len(nfolds)) {
    cat("Fold", ii, ":")
    held_out <- folds[[ii]]
    # drop = FALSE keeps a matrix even when a split contains a single row.
    x_train <- x[-held_out, , drop = FALSE]
    x_test <- x[held_out, , drop = FALSE]

    # Only the selected fitter is looked up, so the other one need not exist.
    fit_path <- if (is_gaussian) hierNet.path else hierNet.logistic.path
    a <- fit_path(x_train, y = y[-held_out],
                  lamlist = lamlist, delta = fit$delta,
                  diagonal = fit$diagonal, strong = fit$strong, trace = trace,
                  stand.main = stand.main, stand.int = stand.int,
                  # ADMM parameters (which will be NULL if strong=F)
                  rho = fit$rho, niter = fit$niter, sym.eps = fit$sym.eps,
                  # GG descent params
                  step = fit$step, maxiter = fit$maxiter,
                  backtrack = fit$backtrack, tol = fit$tol)

    pred <- predict(a, newx = x_test)
    yhatt <- if (is_gaussian) pred else pred$yhat
    yhats[[ii]] <- yhatt

    # Held-out responses repeated across the lambda columns so they can be
    # compared elementwise with the prediction matrix.
    temp <- matrix(y[held_out], nrow = length(held_out), ncol = n.lamlist)
    err2[ii, ] <- colMeans(errfun(yhatt, temp))
    cat("\n")
  }

  errm <- colMeans(err2)
  errse <- sqrt(apply(err2, 2, var) / nfolds)
  o <- which.min(errm)
  lamhat <- lamlist[o]
  # One-standard-error rule: first lambda at least as large as lamhat whose
  # error is within one SE of the minimum.
  oo <- errm <= errm[o] + errse[o]
  lamhat.1se <- lamlist[oo & lamlist >= lamhat][1]

  # Computed from `fit` itself, so this has one entry per element of
  # fit$lamlist even when a different `lamlist` was supplied.
  nonzero <- colSums(fit$bp - fit$bn != 0) +
    apply(fit$th != 0, 3,
          function(a) sum(diag(a)) + sum((a + t(a) != 0)[upper.tri(a)]))

  obj <- list(lamlist = lamlist, cv.err = errm, cv.se = errse,
              lamhat = lamhat, lamhat.1se = lamhat.1se,
              nonzero = nonzero, folds = folds,
              yhat = yhats,
              call = this.call)
  class(obj) <- "hierNet.cv"
  obj
}

# Build the scaled design matrix, drop observations with a missing outcome,
# compute held-out predictions with get_yhats(), and save them to disk.
#
# Args:
#   input:       data frame holding the predictor columns.
#   outcome:     response vector aligned with the rows of `input`.
#   includevars: character vector of column names of `input` to use.
#   fit:         fitted hierNet path object, passed to get_yhats().
#   mod.cv:      cross-validation object whose `$folds` get_yhats() reuses.
#   lamlist:     lambda value(s) to refit at.
#   save_path:   output directory prefix; it is pasted directly in front of the
#                file name, so it must end with a path separator.
#   name:        suffix of the output file: <save_path>modcv_yhats_<name>.RData.
#   seed:        RNG seed set before any work is done.
#
# Returns:
#   The object returned by get_yhats(), invisibly. The same object is saved
#   under the name `yhats` in the output file.
run_heldout_yhat <- function(input, 
                         outcome, 
                         includevars,
                         fit,
                         mod.cv,
                         lamlist,
                         save_path = "../output/",
                         name,
                         seed = 11111){
  # library() rather than require(): fail immediately if hierNet is missing.
  library(hierNet)
  set.seed(seed)

  message('shaping input')
  X <- input %>%
    dplyr::select(all_of(includevars)) %>%
    as.matrix()

  # Columns are scaled using all rows, before incomplete cases are dropped.
  X_scale <- as.matrix(data.frame(scale(X)))

  Y <- outcome
  # Compute the missingness mask once and apply it to both objects. Filtering
  # Y first and then re-deriving the indices from the filtered Y yields an
  # empty index, and a negative empty index removes every row of X_scale.
  missing_y <- is.na(Y)
  if (any(missing_y)) {
    message("removing ", sum(missing_y), " observations due to missingness")
    Y <- Y[!missing_y]
    X_scale <- X_scale[!missing_y, , drop = FALSE]
  }

  # Use the `fit` argument, not a global object that happens to be called `mod`.
  yhats <- get_yhats(fit = fit,
                     mod.cv = mod.cv,
                     x = X_scale,
                     y = Y,
                     lamlist = lamlist,
                     trace = 1)

  save(yhats, file = paste0(save_path, "modcv_yhats_", name, ".RData"))

  # Hand the predictions back so callers can assign the result.
  invisible(yhats)
}


# Predictor table. Every column except the first is passed as a predictor
# below (names(pundits_agg0)[-1]).
pundits_agg0 <- readr::read_csv("../data/pundits_parrot_aggregated_imp0.csv")

# Outcome table; its `ideal.point` column is the response used below.
# NOTE(review): rows are assumed to be aligned with pundits_agg0 -- confirm.
ideal.points_d1 <- readr::read_csv("../data/ideal_points_masked.csv")

# Row indices with a missing `tweetscore`; not referenced again in the visible
# part of this file.
which.na.ts <- which(is.na(ideal.points_d1$tweetscore))

# NOTE(review): `mod` and `mod.cv` are used below but never assigned in this
# file, so these load() calls are expected to define them -- confirm which
# .RData file provides each object.
load("../output/mod_allvars.RData")
load("../output/modcv_yhats_allvars.RData")

# Lambda with the smallest cross-validated error.
bestmod <- mod.cv$lamlist[which.min(mod.cv$cv.err)]

# Refit at that single lambda on each stored fold; run_heldout_yhat() writes
# the result to ../output/modcv_yhats_allvars_heldout_yhats.RData.
message("getting yhats...")
modall_yhats <- run_heldout_yhat(
            input = pundits_agg0,
             fit = mod,
             mod.cv = mod.cv,
             lamlist = bestmod,
             outcome = ideal.points_d1$ideal.point,
             includevars = names(pundits_agg0)[-1],
            name = "allvars_heldout_yhats")